import gfootball.env as football_env
from gfootball.env import observation_preprocessing
import gym
import numpy as np
import matplotlib.pyplot as plt
from smac.env import StarCraft2Env

class StarCraftMultiAgentEnv(object):
    """A wrapper for SMAC's ``StarCraft2Env`` to make it compatible with our codebase.

    Most methods delegate directly to the wrapped environment. On top of that
    the wrapper keeps a step counter (``time_step``), computes a pairwise ally
    visibility mask (``get_p_state``) and exposes an extended ``get_env_info``.
    """

    def __init__(self, args):
        """Create the wrapped environment and cache its basic dimensions.

        Args:
            args: Configuration object providing ``env`` (used as the map
                name), ``step_mul``, ``seed``, ``difficulty``,
                ``game_version`` and ``replay_dir``.
        """
        self.time_step = 0

        self.env = StarCraft2Env(
            map_name=args.env,
            step_mul=args.step_mul,
            seed=args.seed,
            difficulty=args.difficulty,
            game_version=args.game_version,
            replay_dir=args.replay_dir,
        )

        self.obs_dim = self.get_obs_size()
        self.state_shape = self.get_state_size()
        self.n_actions = self.env.n_actions
        self.n_enemy = self.env.n_enemies
        self.time_limit = self.env.episode_limit
        self.n_agents = self.env.n_agents
        # Reported as 'p_state' in get_env_info(); set to the number of agents.
        self.p_state = self.n_agents
        # Queried for agent id 1 only; presumably the same for every unit
        # -- TODO confirm against the installed SMAC version.
        self.sight_range = self.env.unit_sight_range(1)

    def get_obs_agent(self, id):
        """Return the wrapped environment's observation for agent ``id``."""
        return self.env.get_obs_agent(id)

    def get_p_state(self):
        """Return the pairwise visibility mask between allied agents.

        Entry ``[i, j]`` is 1.0 when allies ``i`` and ``j`` both have a
        non-zero value in column 0 of their state row and are closer to each
        other than ``self.sight_range``; otherwise it is 0.0.

        Returns:
            A float32 array with one row and one column per ally.
        """
        ally_state = self.env.get_state_dict()["allies"]

        # Columns 2:4 are treated as (x, y) normalised by the env's
        # max_distance_x / max_distance_y, and column 0 as health
        # (per SMAC's state layout -- TODO confirm for the installed version).
        scale = np.array([self.env.max_distance_x, self.env.max_distance_y])
        positions = ally_state[:, 2:4] * scale

        # Broadcasting gives offsets[i, j] = positions[i] - positions[j].
        offsets = positions[:, np.newaxis, :] - positions[np.newaxis, :, :]
        distances = np.linalg.norm(offsets, axis=-1)

        # A pair only counts when both agents are alive (non-zero column 0).
        alive = ally_state[:, 0].astype('bool').astype('float32')
        both_alive = np.outer(alive, alive)

        return (distances < self.sight_range).astype('float32') * both_alive

    def reset(self):
        """Reset the step counter and the wrapped environment.

        Returns:
            Whatever the wrapped environment's ``reset()`` returns.
        """
        self.time_step = 0
        return self.env.reset()

    def step(self, actions):
        """Advance the wrapped environment by one step.

        Args:
            actions: Passed unchanged to the wrapped environment's ``step``.

        Returns:
            Whatever the wrapped environment's ``step(actions)`` returns.
        """
        self.time_step += 1
        return self.env.step(actions)

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def get_state(self):
        """Return the wrapped environment's global state."""
        return self.env.get_state()

    def get_obs(self):
        """Return the wrapped environment's observations for all agents."""
        return self.env.get_obs()

    def get_obs_size(self):
        """Return the observation size reported by the wrapped environment."""
        return self.env.get_obs_size()

    def get_state_size(self):
        """Return the state size reported by the wrapped environment."""
        return self.env.get_state_size()

    def get_avail_actions(self):
        """Return the wrapped environment's available actions for all agents."""
        return self.env.get_avail_actions()

    def get_avail_agent_actions(self, id):
        """Return the available actions for agent ``id`` only.

        Previously this ignored ``id`` and returned the all-agent list from
        ``get_avail_actions()``; use ``get_avail_actions()`` when the full
        list is wanted.
        """
        return self.env.get_avail_agent_actions(id)

    def get_env_info(self):
        """Return a dict describing the environment's dimensions.

        Returns:
            dict: Cached sizes from ``__init__`` plus the ally/own feature
            sizes queried from the wrapped environment.
        """
        # NOTE(review): the two values are assumed to be (number of allies,
        # features per ally) -- confirm against the wrapped env.
        ally_num, ally_size = self.env.get_obs_ally_feats_size()
        return {
            'n_actions': self.n_actions,
            'obs_shape': self.obs_dim,
            'n_agents': self.n_agents,
            'state_shape': self.state_shape,
            'episode_limit': self.time_limit,
            'n_enemy': self.n_enemy,
            'p_state': self.p_state,
            # Key spelling kept as-is; consumers may look it up by this name.
            'reltive_loc_dim': 2,
            "ally_feats_dim": ally_num * ally_size,
            "own_feats_dim": self.env.get_obs_own_feats_size(),
            "visual_r": 1,
        }


# def make_football_env(seed_dir, dump_freq=1000, representation='extracted', render=False):
def make_StarCraft_MM2(args):
    """Build a ``StarCraftMultiAgentEnv`` from ``args``.

    ``args`` is forwarded unchanged to the ``StarCraftMultiAgentEnv``
    constructor. The returned object is driven by calling ``reset()`` and
    ``step(actions)`` on it.

    Args:
        args: Configuration object accepted by ``StarCraftMultiAgentEnv``.

    Returns:
        The newly constructed ``StarCraftMultiAgentEnv`` instance.
    """
    env = StarCraftMultiAgentEnv(args)
    return env
